# encoding:utf-8
import cv2
import numpy as np
import requests
import base64
import glob
import os


def _fit_within(w, h, max_w, max_h):
    """Return (new_w, new_h): (w, h) shrunk, aspect ratio kept, to fit max_w x max_h.

    Sizes already inside the box are returned unchanged (never upscaled).
    """
    if w <= max_w and h <= max_h:
        return w, h
    # Compare w / max_w against h / max_h without dividing: the larger ratio
    # is the dimension that limits the scale factor.
    if w * max_h >= h * max_w:
        return max_w, max(1, int(h * max_w / w))
    return max(1, int(w * max_h / h)), max_h


def PreprocessInputs(raw_dir, suggest_size, begin_idx, out_dir='./INPUT/'):
    """Downscale every image in raw_dir and save it as a numbered PNG.

    Each readable image is shrunk (aspect ratio preserved, never enlarged) so
    it fits inside suggest_size = [max_w, max_h]. For portrait images
    (h > w) the box is rotated, i.e. the two limits are swapped. Outputs are
    written as out_dir/%08d.PNG, numbered consecutively from begin_idx in
    sorted filename order. Files OpenCV cannot read are skipped.

    Args:
        raw_dir: directory whose entries are scanned for images.
        suggest_size: two-element sequence (max_w, max_h) for landscape images.
        begin_idx: index used for the first written file.
        out_dir: destination directory (created if missing).

    Returns:
        The next unused index (begin_idx + number of images written), so
        several raw directories can be appended by chaining calls.

    Raises:
        OSError: if an output image could not be written.
    """
    os.makedirs(out_dir, exist_ok=True)
    idx = begin_idx
    for path in sorted(glob.glob(os.path.join(raw_dir, '*'))):
        print(path)
        # IMREAD_UNCHANGED (-1): load as-is, keeping e.g. an alpha channel.
        im_ = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if im_ is None:
            # cv2.imread signals unreadable / non-image files by returning None.
            print('Skipping unreadable file: %s' % path)
            continue
        h, w = im_.shape[:2]
        max_w, max_h = suggest_size[0], suggest_size[1]
        if h > w:
            # Portrait: apply the long limit to the long side.
            max_w, max_h = max_h, max_w
        new_w, new_h = _fit_within(w, h, max_w, max_h)
        if (new_w, new_h) != (w, h):
            im_ = cv2.resize(im_, (new_w, new_h))
        out_path = os.path.join(out_dir, '%08d.PNG' % idx)
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(out_path, im_):
            raise OSError('Failed to write %s' % out_path)
        idx += 1
    return idx


def GetLatestTokenID(client_key, client_secret, timeout=10):
    """Fetch a fresh access token from the Baidu AI Cloud OAuth endpoint.

    Uses the client-credentials grant against
    https://aip.baidubce.com/oauth/2.0/token.

    Args:
        client_key: API key (sent as client_id).
        client_secret: secret key (sent as client_secret).
        timeout: seconds to wait for the HTTP response.

    Returns:
        The access token string.

    Raises:
        requests.RequestException: on network failure or timeout.
        RuntimeError: if the response does not contain an access token
            (previously this silently returned None or raised KeyError).
    """
    # Pass credentials via params= so requests URL-encodes them.
    response = requests.get(
        'https://aip.baidubce.com/oauth/2.0/token',
        params={
            'grant_type': 'client_credentials',
            'client_id': client_key,
            'client_secret': client_secret,
        },
        timeout=timeout,
    )
    try:
        res = response.json()
    except ValueError:
        res = None
    token = res.get('access_token') if isinstance(res, dict) else None
    if not token:
        raise RuntimeError('Failed to obtain access token (HTTP %d): %s'
                           % (response.status_code, response.text[:200]))
    return token


def _post_body_seg(request_url, access_token, img_b64, timeout):
    """Send one body_seg request; return the parsed JSON, or None on failure."""
    try:
        response = requests.post(
            request_url,
            params={'access_token': access_token},
            data={'image': img_b64},
            headers={'content-type': 'application/x-www-form-urlencoded'},
            timeout=timeout,
        )
        response.raise_for_status()
    except requests.RequestException as err:
        print('Request failed: %s' % err)
        return None
    try:
        return response.json()
    except ValueError:
        print('Json parsing failed!')
        print(response)
        return None


def SegmentHumanBody(client_key=None, client_secret=None, in_dir='./INPUT/',
                     out_dir='./OUTPUT/', max_retries=5, timeout=30):
    """Run Baidu body segmentation on every in_dir/*.PNG and save the foreground.

    For each input PNG (sorted order), the base64-encoded image is posted to
    the body_seg endpoint. The 'foreground' field of the reply is decoded and
    written to out_dir under the same file name. Failed requests are retried
    up to max_retries times, with a growing delay and a refreshed access token
    between attempts.

    Args:
        client_key / client_secret: Baidu credentials. When None, they are
            read from BAIDU_CLIENT_KEY / BAIDU_CLIENT_SECRET, falling back to
            the legacy hard-coded values.
        in_dir: directory containing the *.PNG inputs.
        out_dir: destination directory (created if missing).
        max_retries: attempts per image before giving up on it.
        timeout: per-request HTTP timeout in seconds.

    Returns:
        List of input paths that could not be segmented.

    Raises:
        OSError: if a result image could not be written.
    """
    import time  # only used for the retry back-off

    # NOTE(review): these credentials are committed in source and should be
    # treated as leaked -- rotate them and pass new ones via the arguments or
    # the environment variables instead.
    if client_key is None:
        client_key = os.environ.get('BAIDU_CLIENT_KEY', 'X18MxolY57DBw7imWYHriRHU')
    if client_secret is None:
        client_secret = os.environ.get('BAIDU_CLIENT_SECRET', '9OUeYDWRQKKZkrMmOtZFaHRVu6ZxVqsx')
    # The access token is sent as a query parameter per request, so this base
    # URL is never mutated (the original appended to it on every iteration).
    request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_seg"
    access_token = GetLatestTokenID(client_key, client_secret)

    os.makedirs(out_dir, exist_ok=True)

    failed = []
    for path in sorted(glob.glob(os.path.join(in_dir, '*.PNG'))):
        print(path)
        with open(path, 'rb') as f:
            img_b64 = base64.b64encode(f.read())

        res = None
        for attempt in range(1, max_retries + 1):
            res = _post_body_seg(request_url, access_token, img_b64, timeout)
            if isinstance(res, dict) and 'foreground' in res:
                break
            print('Attempt %d/%d failed for %s' % (attempt, max_retries, path))
            if isinstance(res, dict):
                # Error replies come back as JSON without 'foreground'.
                print('API error %s: %s' % (res.get('error_code'), res.get('error_msg')))
            if attempt < max_retries:
                time.sleep(attempt)
                # The token may have expired mid-run; fetch a fresh one.
                access_token = GetLatestTokenID(client_key, client_secret)
        else:
            failed.append(path)
            continue

        foreground = base64.b64decode(res['foreground'])
        im_fg = cv2.imdecode(np.frombuffer(foreground, np.uint8), cv2.IMREAD_UNCHANGED)
        if im_fg is None:
            print('Could not decode foreground for %s' % path)
            failed.append(path)
            continue
        out_path = os.path.join(out_dir, os.path.basename(path))
        if not cv2.imwrite(out_path, im_fg):
            raise OSError('Failed to write %s' % out_path)
    return failed
        
        
if __name__ == '__main__':
    raw_dir, target_size, first_index = './RAW', [960, 720], 0
    # Stage 1: resize the raw photos into numbered PNGs.
    PreprocessInputs(raw_dir, target_size, first_index)
    # Stage 2: run body segmentation on those PNGs.
    SegmentHumanBody()
    